#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Main entry point for the FastAPI OpenAPI generator.
"""

import copy
from pathlib import Path
from typing import Any

import yaml
from fastapi.openapi.utils import get_openapi

from . import app, schema_collection, schema_filtering, schema_transforms, state


def generate_openapi_spec(output_dir: str) -> dict[str, Any]:
    """
    Generate OpenAPI specification using FastAPI's built-in method.

    Args:
        output_dir: Directory to save the generated files

    Returns:
        The generated OpenAPI specification as a dictionary
    """
    state.reset_generator_state()
    # Create the FastAPI app
    fastapi_app = app.create_llama_stack_app()

    # Generate the OpenAPI schema
    openapi_schema = get_openapi(
        title=fastapi_app.title,
        version=fastapi_app.version,
        description=fastapi_app.description,
        routes=fastapi_app.routes,
        servers=fastapi_app.servers,
    )

    # Set OpenAPI version to 3.1.0
    openapi_schema["openapi"] = "3.1.0"

    # Add standard error responses
    openapi_schema = schema_transforms._add_error_responses(openapi_schema)

    # Ensure all @json_schema_type decorated models are included
    openapi_schema = schema_collection._ensure_json_schema_types_included(openapi_schema)

    # Fix $ref references to point to components/schemas instead of $defs
    openapi_schema = schema_transforms._fix_ref_references(openapi_schema)

    # Fix path parameter resolution issues
    openapi_schema = schema_transforms._fix_path_parameters(openapi_schema)

    # Eliminate $defs section entirely for oasdiff compatibility
    openapi_schema = schema_transforms._eliminate_defs_section(openapi_schema)

    # Clean descriptions in schema definitions by removing docstring metadata
    openapi_schema = schema_transforms._clean_schema_descriptions(openapi_schema)
    openapi_schema = schema_transforms._normalize_empty_responses(openapi_schema)

    # Remove query parameters from POST/PUT/PATCH endpoints that have a request body
    # FastAPI sometimes infers parameters as query params even when they should be in the request body
    openapi_schema = schema_transforms._remove_query_params_from_body_endpoints(openapi_schema)

    # Add x-llama-stack-extra-body-params extension for ExtraBodyField parameters
    openapi_schema = schema_transforms._add_extra_body_params_extension(openapi_schema)

    # Remove request bodies from GET endpoints (GET requests should never have request bodies)
    # This must run AFTER _add_extra_body_params_extension to ensure any request bodies
    # that FastAPI incorrectly added to GET endpoints are removed
    openapi_schema = schema_transforms._remove_request_bodies_from_get_endpoints(openapi_schema)

    # Extract duplicate union types to shared schema references
    openapi_schema = schema_transforms._extract_duplicate_union_types(openapi_schema)

    # Split into stable (v1 only), experimental (v1alpha + v1beta), deprecated, and combined (stainless) specs
    # Each spec needs its own deep copy of the full schema to avoid cross-contamination
    stable_schema = schema_filtering._filter_schema_by_version(
        copy.deepcopy(openapi_schema), stable_only=True, exclude_deprecated=True
    )
    experimental_schema = schema_filtering._filter_schema_by_version(
        copy.deepcopy(openapi_schema), stable_only=False, exclude_deprecated=True
    )
    deprecated_schema = schema_filtering._filter_deprecated_schema(copy.deepcopy(openapi_schema))
    combined_schema = schema_filtering._filter_combined_schema(copy.deepcopy(openapi_schema))

    # Apply duplicate union extraction to combined schema (used by Stainless)
    combined_schema = schema_transforms._extract_duplicate_union_types(combined_schema)

    base_description = (
        "This is the specification of the Llama Stack that provides\n"
        "                    a set of endpoints and their corresponding interfaces that are\n"
        "    tailored to\n"
        "                    best leverage Llama Models."
    )

    schema_configs = [
        (
            stable_schema,
            "Llama Stack Specification",
            "**✅ STABLE**: Production-ready APIs with backward compatibility guarantees.",
        ),
        (
            experimental_schema,
            "Llama Stack Specification - Experimental APIs",
            "**🧪 EXPERIMENTAL**: Pre-release APIs (v1alpha, v1beta) that may change before\n    becoming stable.",
        ),
        (
            deprecated_schema,
            "Llama Stack Specification - Deprecated APIs",
            "**⚠️ DEPRECATED**: Legacy APIs that may be removed in future versions. Use for\n    migration reference only.",
        ),
        (
            combined_schema,
            "Llama Stack Specification - Stable & Experimental APIs",
            "**🔗 COMBINED**: This specification includes both stable production-ready APIs\n    and experimental pre-release APIs. Use stable APIs for production deployments\n    and experimental APIs for testing new features.",
        ),
    ]

    for schema, title, description_suffix in schema_configs:
        if "info" not in schema:
            schema["info"] = {}
        schema["info"].update(
            {
                "title": title,
                "version": "v1",
                "description": f"{base_description}\n\n    {description_suffix}",
            }
        )

    schemas_to_validate = [
        (stable_schema, "Stable schema"),
        (experimental_schema, "Experimental schema"),
        (deprecated_schema, "Deprecated schema"),
        (combined_schema, "Combined (stainless) schema"),
    ]

    for schema, _ in schemas_to_validate:
        schema_transforms._fix_schema_issues(schema)
        schema_transforms._apply_legacy_sorting(schema)

    print("\nValidating generated schemas...")
    failed_schemas = [
        name for schema, name in schemas_to_validate if not schema_transforms.validate_openapi_schema(schema, name)
    ]
    if failed_schemas:
        raise ValueError(f"Invalid schemas: {', '.join(failed_schemas)}")

    # Ensure output directory exists
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save the stable specification
    yaml_path = output_path / "llama-stack-spec.yaml"
    schema_transforms._write_yaml_file(yaml_path, stable_schema)
    # Post-process the YAML file to remove $defs section and fix references
    with open(yaml_path) as f:
        yaml_content = f.read()

    if "  $defs:" in yaml_content or "#/$defs/" in yaml_content:
        # Use string replacement to fix references directly
        if "#/$defs/" in yaml_content:
            yaml_content = yaml_content.replace("#/$defs/", "#/components/schemas/")

        # Parse the YAML content
        yaml_data = yaml.safe_load(yaml_content)

        # Move $defs to components/schemas if it exists
        if "$defs" in yaml_data:
            if "components" not in yaml_data:
                yaml_data["components"] = {}
            if "schemas" not in yaml_data["components"]:
                yaml_data["components"]["schemas"] = {}

            # Move all $defs to components/schemas
            for def_name, def_schema in yaml_data["$defs"].items():
                yaml_data["components"]["schemas"][def_name] = def_schema

            # Remove the $defs section
            del yaml_data["$defs"]

        # Write the modified YAML back
        schema_transforms._write_yaml_file(yaml_path, yaml_data)

    print(f"Generated YAML (stable): {yaml_path}")

    experimental_yaml_path = output_path / "experimental-llama-stack-spec.yaml"
    schema_transforms._write_yaml_file(experimental_yaml_path, experimental_schema)
    print(f"Generated YAML (experimental): {experimental_yaml_path}")

    deprecated_yaml_path = output_path / "deprecated-llama-stack-spec.yaml"
    schema_transforms._write_yaml_file(deprecated_yaml_path, deprecated_schema)
    print(f"Generated YAML (deprecated): {deprecated_yaml_path}")

    # Generate combined (stainless) spec
    stainless_yaml_path = output_path / "stainless-llama-stack-spec.yaml"
    schema_transforms._write_yaml_file(stainless_yaml_path, combined_schema)
    print(f"Generated YAML (stainless/combined): {stainless_yaml_path}")

    return stable_schema


def main():
    """Main entry point for the FastAPI OpenAPI generator."""
    import argparse

    parser = argparse.ArgumentParser(description="Generate OpenAPI specification using FastAPI")
    parser.add_argument("output_dir", help="Output directory for generated files")

    args = parser.parse_args()

    print("Generating OpenAPI specification using FastAPI...")
    print(f"Output directory: {args.output_dir}")

    try:
        openapi_schema = generate_openapi_spec(output_dir=args.output_dir)

        print("\nOpenAPI specification generated successfully!")
        print(f"Schemas: {len(openapi_schema.get('components', {}).get('schemas', {}))}")
        print(f"Paths: {len(openapi_schema.get('paths', {}))}")
        operation_count = sum(
            1
            for path_info in openapi_schema.get("paths", {}).values()
            for method in ["get", "post", "put", "delete", "patch"]
            if method in path_info
        )
        print(f"Operations: {operation_count}")

    except Exception as e:
        print(f"Error generating OpenAPI specification: {e}")
        raise


if __name__ == "__main__":
    main()
